{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b076bd1a-b236-4fbc-953d-8295b25122ae",
   "metadata": {},
   "source": [
    "# 🎶 Music Generation with Transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9235cbd1-f136-411c-88d9-f69f270c0b96",
   "metadata": {},
   "source": [
    "In this notebook, we'll walk through the steps required to train your own Transformer model to generate music in the style of the Bach cello suites."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84acc7be-6764-4668-b2bb-178f63deeed3",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers, models, losses, callbacks\n",
    "\n",
    "import music21\n",
    "\n",
    "from transformer_utils import (\n",
    "    parse_midi_files,\n",
    "    load_parsed_files,\n",
    "    get_midi_note,\n",
    "    SinePositionEncoding,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "339e6268-ebd7-4feb-86db-1fe7abccdbe5",
   "metadata": {},
   "source": [
    "## 0. Parameters <a name=\"parameters\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b2ee6ce-129f-4833-b0c5-fa567381c4e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data: re-parse the MIDI files (True) or load a previously parsed copy (False)\n",
    "PARSE_MIDI_FILES = True\n",
    "PARSED_DATA_PATH = \"/app/notebooks/11_music/01_transformer/parsed_data/\"\n",
    "# Passed to Dataset.repeat() when the training dataset is built\n",
    "DATASET_REPETITIONS = 1\n",
    "\n",
    "# Model: SEQ_LEN + 1 is handed to parse_midi_files; inputs / targets drop one token each\n",
    "SEQ_LEN = 50\n",
    "# Width of the concatenated embedding (notes and durations get EMBEDDING_DIM // 2 each)\n",
    "EMBEDDING_DIM = 256\n",
    "# KEY_DIM and N_HEADS are forwarded to MultiHeadAttention inside TransformerBlock\n",
    "KEY_DIM = 256\n",
    "N_HEADS = 5\n",
    "# Default dropout rate of TransformerBlock\n",
    "DROPOUT_RATE = 0.3\n",
    "# Units of the hidden feed-forward Dense layer in TransformerBlock\n",
    "FEED_FORWARD_DIM = 256\n",
    "# When True, checkpoint weights are loaded into the model before training\n",
    "LOAD_MODEL = False\n",
    "\n",
    "# optimization\n",
    "EPOCHS = 5000\n",
    "BATCH_SIZE = 256\n",
    "\n",
    "# max_tokens used for the sample generated at the end of every epoch\n",
    "GENERATE_LEN = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d4f5e63-e36a-4dc8-9f03-cb29c1fa5290",
   "metadata": {},
   "source": [
    "## 1. Prepare the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73de38bd-0b92-4441-9601-ed4a3b45f924",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data\n",
    "# glob returns paths in arbitrary, filesystem-dependent order. Sort them so that\n",
    "# positional lookups further down (e.g. file_list[1]) are reproducible across machines.\n",
    "file_list = sorted(glob.glob(\"/app/data/bach-cello/*.mid\"))\n",
    "print(f\"Found {len(file_list)} midi files\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1ff575d-1632-43bf-844b-e5f2cea61454",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alias for the music21 converter module; it is passed to parse_midi_files as the parser\n",
    "parser = music21.converter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b563da4f-0b08-4005-aa35-ca58a60a7def",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Parse the second MIDI file, take the first piece returned by\n",
    "# splitAtQuarterLength(12) and chordify it for display.\n",
    "example_stream = music21.converter.parse(file_list[1])\n",
    "example_excerpt = example_stream.splitAtQuarterLength(12)[0]\n",
    "example_score = example_excerpt.chordify()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "332bb057-d11d-41b3-9328-91dc7479a110",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the example excerpt using music21's default show() output\n",
    "example_score.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "debce40c-2d56-4140-b65b-459c6464c1a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same excerpt again, this time through music21's \"text\" output format\n",
    "example_score.show(\"text\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77459313-2417-4c18-b938-3a9859ec9bd9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Either re-parse the raw MIDI files or reuse a previously parsed copy.\n",
    "if PARSE_MIDI_FILES:\n",
    "    # SEQ_LEN + 1 is passed so that, after the one-token shift between inputs and\n",
    "    # targets, each side keeps SEQ_LEN tokens (presumably the sequence length - confirm).\n",
    "    notes, durations = parse_midi_files(\n",
    "        file_list, parser, SEQ_LEN + 1, PARSED_DATA_PATH\n",
    "    )\n",
    "else:\n",
    "    # Read back from the same location the parse branch is given.\n",
    "    # NOTE(review): assumes load_parsed_files takes the parsed-data path as its\n",
    "    # argument - confirm against transformer_utils.\n",
    "    notes, durations = load_parsed_files(PARSED_DATA_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99f17b33-193e-4d83-8e0a-54dc8e7b249d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show one parsed sequence: its notes and the durations stored at the same index\n",
    "example_idx = 658\n",
    "example_notes, example_durations = notes[example_idx], durations[example_idx]\n",
    "for label, sequence in (\n",
    "    (\"\\nNotes string\\n\", example_notes),\n",
    "    (\"\\nDuration string\\n\", example_durations),\n",
    "):\n",
    "    print(label, sequence, \"...\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aba2f39c-882f-423a-88eb-334502413639",
   "metadata": {},
   "source": [
    "## 2. Tokenize the data <a name=\"tokenize\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "244e093e-b977-4b1f-8f46-6503a55ea0fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset(elements):\n",
    "    \"\"\"Batch `elements` and fit a TextVectorization layer on them.\n",
    "\n",
    "    Returns a tuple (ds, vectorize_layer, vocab): the batched dataset (in its\n",
    "    original order), the adapted layer and the vocabulary list it learned.\n",
    "    \"\"\"\n",
    "    # No shuffling here: the notes and durations datasets are zipped together\n",
    "    # below, so both must keep the same batch order to stay aligned.\n",
    "    ds = tf.data.Dataset.from_tensor_slices(elements).batch(\n",
    "        BATCH_SIZE, drop_remainder=True\n",
    "    )\n",
    "    # standardize=None switches off the layer's default text standardization\n",
    "    vectorize_layer = layers.TextVectorization(\n",
    "        standardize=None, output_mode=\"int\"\n",
    "    )\n",
    "    vectorize_layer.adapt(ds)\n",
    "    vocab = vectorize_layer.get_vocabulary()\n",
    "    return ds, vectorize_layer, vocab\n",
    "\n",
    "\n",
    "notes_seq_ds, notes_vectorize_layer, notes_vocab = create_dataset(notes)\n",
    "durations_seq_ds, durations_vectorize_layer, durations_vocab = create_dataset(\n",
    "    durations\n",
    ")\n",
    "# Shuffle once, after zipping, so every notes batch stays paired with its own\n",
    "# durations batch. Shuffling the two datasets separately pairs unrelated batches.\n",
    "seq_ds = tf.data.Dataset.zip((notes_seq_ds, durations_seq_ds)).shuffle(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ec4d36c-a4ad-4c32-89a2-749c21786441",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the example sequence and print its first 11 note / duration ids side by side\n",
    "example_tokenised_notes = notes_vectorize_layer(example_notes)\n",
    "example_tokenised_durations = durations_vectorize_layer(example_durations)\n",
    "header = \"{:10} {:10}\".format(\"note token\", \"duration token\")\n",
    "print(header)\n",
    "first_note_ids = example_tokenised_notes.numpy()[:11]\n",
    "first_duration_ids = example_tokenised_durations.numpy()[:11]\n",
    "for note_int, duration_int in zip(first_note_ids, first_duration_ids):\n",
    "    print(f\"{note_int:10}{duration_int:10}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dc29b17-1591-4b02-98e1-7f25960350be",
   "metadata": {},
   "outputs": [],
   "source": [
    "notes_vocab_size = len(notes_vocab)\n",
    "durations_vocab_size = len(durations_vocab)\n",
    "\n",
    "\n",
    "def _print_first_tokens(vocab, limit=10):\n",
    "    \"\"\"Print the first `limit` entries of `vocab` as 'index: token' lines.\"\"\"\n",
    "    for i, note in enumerate(vocab[:limit]):\n",
    "        print(f\"{i}: {note}\")\n",
    "\n",
    "\n",
    "# Display some token:note mappings\n",
    "print(f\"\\nNOTES_VOCAB: length = {len(notes_vocab)}\")\n",
    "_print_first_tokens(notes_vocab)\n",
    "\n",
    "# Display some token:duration mappings\n",
    "print(f\"\\nDURATIONS_VOCAB: length = {len(durations_vocab)}\")\n",
    "_print_first_tokens(durations_vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "823fb0c1-ebf8-453b-be94-9a33b466cae4",
   "metadata": {},
   "source": [
    "## 3. Create the Training Set <a name=\"create\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f0f2a52-b157-478d-8de7-11ed6c383461",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build (input, target) pairs: the target is the input shifted left by one token\n",
    "def prepare_inputs(notes, durations):\n",
    "    \"\"\"Tokenize a batch of note / duration strings and split it into x and y.\n",
    "\n",
    "    Returns (x, y), each a (notes, durations) pair of token tensors: x drops the\n",
    "    last token of every sequence, y drops the first.\n",
    "    \"\"\"\n",
    "    # A trailing axis of size 1 is added before the strings reach the layers\n",
    "    note_tokens = notes_vectorize_layer(tf.expand_dims(notes, -1))\n",
    "    duration_tokens = durations_vectorize_layer(tf.expand_dims(durations, -1))\n",
    "    inputs = (note_tokens[:, :-1], duration_tokens[:, :-1])\n",
    "    targets = (note_tokens[:, 1:], duration_tokens[:, 1:])\n",
    "    return inputs, targets\n",
    "\n",
    "\n",
    "ds = seq_ds.map(prepare_inputs).repeat(DATASET_REPETITIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78248965-1716-4077-8db4-191e98c7a6a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull a single (inputs, targets) batch out of the training dataset and print it\n",
    "example_input_output = ds.take(1).get_single_element()\n",
    "print(example_input_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bb6376b-c4f9-4946-a736-94f1739c3149",
   "metadata": {},
   "source": [
    "## 5. Create the causal attention mask function <a name=\"causal\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb6fa02-3b59-48eb-bcef-d51c9d32ec18",
   "metadata": {},
   "outputs": [],
   "source": [
    "def causal_attention_mask(batch_size, n_dest, n_src, dtype):\n",
    "    \"\"\"Return a (batch_size, n_dest, n_src) mask cast to `dtype`.\n",
    "\n",
    "    Entry [b, i, j] is truthy when i >= j - n_src + n_dest; for n_dest == n_src\n",
    "    that means source position j does not lie after destination position i.\n",
    "    \"\"\"\n",
    "    dest_idx = tf.range(n_dest)[:, None]\n",
    "    src_idx = tf.range(n_src)\n",
    "    allowed = dest_idx >= src_idx - n_src + n_dest\n",
    "    single_mask = tf.reshape(tf.cast(allowed, dtype), [1, n_dest, n_src])\n",
    "    # Tile multiples are [batch_size, 1, 1]: repeat the single mask along axis 0\n",
    "    multiples = tf.concat(\n",
    "        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n",
    "    )\n",
    "    return tf.tile(single_mask, multiples)\n",
    "\n",
    "\n",
    "# Transposed view of one 10 x 10 mask\n",
    "np.transpose(causal_attention_mask(1, 10, 10, dtype=tf.int32)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52210a38-c8b4-4da7-90ce-eaa89c8fcafa",
   "metadata": {},
   "source": [
    "## 6. Create a Transformer Block layer <a name=\"transformer\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc3f25e8-4e3f-4849-9b92-676ea46e3ed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerBlock(layers.Layer):\n",
    "    \"\"\"Causal self-attention followed by a feed-forward network.\n",
    "\n",
    "    Both sub-layers are followed by dropout, a residual connection and layer\n",
    "    normalization. Calling the layer returns (output, attention_scores).\n",
    "\n",
    "    Args:\n",
    "        num_heads: number of attention heads (passed to MultiHeadAttention).\n",
    "        key_dim: key_dim argument of MultiHeadAttention.\n",
    "        embed_dim: width of the block's input and output.\n",
    "        ff_dim: units of the hidden feed-forward Dense layer.\n",
    "        name: layer name.\n",
    "        dropout_rate: rate used by both dropout layers.\n",
    "        **kwargs: remaining base-layer arguments (e.g. trainable, dtype).\n",
    "            get_config() emits them, so they must be accepted here for\n",
    "            from_config(get_config()) to rebuild the layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        num_heads,\n",
    "        key_dim,\n",
    "        embed_dim,\n",
    "        ff_dim,\n",
    "        name,\n",
    "        dropout_rate=DROPOUT_RATE,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        super().__init__(name=name, **kwargs)\n",
    "        self.num_heads = num_heads\n",
    "        self.key_dim = key_dim\n",
    "        self.embed_dim = embed_dim\n",
    "        self.ff_dim = ff_dim\n",
    "        self.dropout_rate = dropout_rate\n",
    "        self.attn = layers.MultiHeadAttention(\n",
    "            num_heads, key_dim, output_shape=embed_dim\n",
    "        )\n",
    "        self.dropout_1 = layers.Dropout(self.dropout_rate)\n",
    "        self.ln_1 = layers.LayerNormalization(epsilon=1e-6)\n",
    "        self.ffn_1 = layers.Dense(self.ff_dim, activation=\"relu\")\n",
    "        self.ffn_2 = layers.Dense(self.embed_dim)\n",
    "        self.dropout_2 = layers.Dropout(self.dropout_rate)\n",
    "        self.ln_2 = layers.LayerNormalization(epsilon=1e-6)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        \"\"\"Apply the block to `inputs`; returns (output, attention_scores).\"\"\"\n",
    "        input_shape = tf.shape(inputs)\n",
    "        batch_size = input_shape[0]\n",
    "        seq_len = input_shape[1]\n",
    "        # Self-attention: query and value are both `inputs`, restricted by the causal mask\n",
    "        causal_mask = causal_attention_mask(\n",
    "            batch_size, seq_len, seq_len, tf.bool\n",
    "        )\n",
    "        attention_output, attention_scores = self.attn(\n",
    "            inputs,\n",
    "            inputs,\n",
    "            attention_mask=causal_mask,\n",
    "            return_attention_scores=True,\n",
    "        )\n",
    "        attention_output = self.dropout_1(attention_output)\n",
    "        # Residual connection + layer norm around the attention sub-layer\n",
    "        out1 = self.ln_1(inputs + attention_output)\n",
    "        ffn_1 = self.ffn_1(out1)\n",
    "        ffn_2 = self.ffn_2(ffn_1)\n",
    "        ffn_output = self.dropout_2(ffn_2)\n",
    "        # Residual connection + layer norm around the feed-forward sub-layer\n",
    "        return (self.ln_2(out1 + ffn_output), attention_scores)\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base-layer config extended with this block's constructor arguments.\"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update(\n",
    "            {\n",
    "                \"key_dim\": self.key_dim,\n",
    "                \"embed_dim\": self.embed_dim,\n",
    "                \"num_heads\": self.num_heads,\n",
    "                \"ff_dim\": self.ff_dim,\n",
    "                \"dropout_rate\": self.dropout_rate,\n",
    "            }\n",
    "        )\n",
    "        return config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37ee1198-3bfc-415a-b247-02b8166133fe",
   "metadata": {},
   "source": [
    "## 7. Create the Token and Position Embedding <a name=\"embedder\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e30225-7fea-4a3c-afbc-b0017d5da661",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TokenAndPositionEmbedding(layers.Layer):\n",
    "    \"\"\"Token embedding plus the output of SinePositionEncoding.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of tokens in the vocabulary (Embedding input_dim).\n",
    "        embed_dim: width of the embedding vectors (Embedding output_dim).\n",
    "        **kwargs: base-layer arguments (e.g. name, trainable, dtype).\n",
    "            get_config() emits them, so they must be accepted here for\n",
    "            from_config(get_config()) to rebuild the layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, embed_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embed_dim = embed_dim\n",
    "        self.token_emb = layers.Embedding(\n",
    "            input_dim=vocab_size,\n",
    "            output_dim=embed_dim,\n",
    "            embeddings_initializer=\"he_uniform\",\n",
    "        )\n",
    "        self.pos_emb = SinePositionEncoding()\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Embed token ids `x` and add the position encoding computed from that embedding.\"\"\"\n",
    "        embedding = self.token_emb(x)\n",
    "        positions = self.pos_emb(embedding)\n",
    "        return embedding + positions\n",
    "\n",
    "    def get_config(self):\n",
    "        \"\"\"Return the base-layer config extended with vocab_size and embed_dim.\"\"\"\n",
    "        config = super().get_config()\n",
    "        config.update(\n",
    "            {\n",
    "                \"vocab_size\": self.vocab_size,\n",
    "                \"embed_dim\": self.embed_dim,\n",
    "            }\n",
    "        )\n",
    "        return config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96ec1c62-c375-4831-99f3-d313f98f39b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tpe = TokenAndPositionEmbedding(notes_vocab_size, 32)\n",
    "token_embedding = tpe.token_emb(example_tokenised_notes)\n",
    "position_embedding = tpe.pos_emb(token_embedding)\n",
    "embedding = tpe(example_tokenised_notes)\n",
    "\n",
    "# One heatmap each for the token embedding, the position encoding and their sum\n",
    "for matrix in (token_embedding, position_embedding, embedding):\n",
    "    plt.imshow(\n",
    "        np.transpose(matrix),\n",
    "        cmap=\"coolwarm\",\n",
    "        interpolation=\"nearest\",\n",
    "        origin=\"lower\",\n",
    "    )\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fee70157-5cb1-466c-bb9d-ba4720a93173",
   "metadata": {},
   "source": [
    "## 8. Build the Transformer model <a name=\"transformer_decoder\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df56f070-228b-4364-94ae-5b6a12349960",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two variable-length integer inputs: note tokens and duration tokens\n",
    "note_inputs = layers.Input(shape=(None,), dtype=tf.int32)\n",
    "durations_inputs = layers.Input(shape=(None,), dtype=tf.int32)\n",
    "# Each input gets its own embedding of width EMBEDDING_DIM // 2 ...\n",
    "note_embeddings = TokenAndPositionEmbedding(\n",
    "    notes_vocab_size, EMBEDDING_DIM // 2\n",
    ")(note_inputs)\n",
    "duration_embeddings = TokenAndPositionEmbedding(\n",
    "    durations_vocab_size, EMBEDDING_DIM // 2\n",
    ")(durations_inputs)\n",
    "# ... and the two halves are concatenated before entering the transformer block\n",
    "embeddings = layers.Concatenate()([note_embeddings, duration_embeddings])\n",
    "# The block is named \"attention\" so it can be looked up with model.get_layer later\n",
    "x, attention_scores = TransformerBlock(\n",
    "    N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM, name=\"attention\"\n",
    ")(embeddings)\n",
    "# Two softmax heads over the note and duration vocabularies share the block output\n",
    "note_outputs = layers.Dense(\n",
    "    notes_vocab_size, activation=\"softmax\", name=\"note_outputs\"\n",
    ")(x)\n",
    "duration_outputs = layers.Dense(\n",
    "    durations_vocab_size, activation=\"softmax\", name=\"duration_outputs\"\n",
    ")(x)\n",
    "model = models.Model(\n",
    "    inputs=[note_inputs, durations_inputs],\n",
    "    outputs=[note_outputs, duration_outputs],  # attention_scores\n",
    ")\n",
    "# One sparse categorical cross-entropy loss per output head\n",
    "model.compile(\n",
    "    \"adam\",\n",
    "    loss=[\n",
    "        losses.SparseCategoricalCrossentropy(),\n",
    "        losses.SparseCategoricalCrossentropy(),\n",
    "    ],\n",
    ")\n",
    "# Second model over the same inputs that outputs the attention scores instead\n",
    "att_model = models.Model(\n",
    "    inputs=[note_inputs, durations_inputs], outputs=attention_scores\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eddbcc80-d0e9-41f0-bd70-032d489369b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the summary of the training model\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b51a1dbc-7185-43ab-b9f9-834d267c99a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally start from saved weights (same path the ModelCheckpoint callback writes to)\n",
    "if LOAD_MODEL:\n",
    "    model.load_weights(\"./checkpoint/checkpoint.ckpt\")\n",
    "    # model = models.load_model('./models/model', compile=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afeab12c-b871-47c0-884c-752238b9f719",
   "metadata": {},
   "source": [
    "## 9. Train the Transformer <a name=\"train\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2003f19-610a-4acb-a225-0f6c1a3d3a1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a MusicGenerator checkpoint\n",
    "class MusicGenerator(callbacks.Callback):\n",
    "    \"\"\"Keras callback that samples music from the model it is attached to.\n",
    "\n",
    "    At the end of every epoch a sequence is generated, displayed and written to\n",
    "    a MIDI file. `generate` can also be called directly once Keras has attached\n",
    "    a model to the callback (it reads `self.model`).\n",
    "\n",
    "    Args:\n",
    "        index_to_note: note vocabulary; list position is the token id.\n",
    "        index_to_duration: duration vocabulary; list position is the token id.\n",
    "        top_k: not used by the current implementation; kept for compatibility.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, index_to_note, index_to_duration, top_k=10):\n",
    "        # Initialise the Keras Callback base class before adding our own state\n",
    "        super().__init__()\n",
    "        self.index_to_note = index_to_note\n",
    "        self.note_to_index = {\n",
    "            note: index for index, note in enumerate(index_to_note)\n",
    "        }\n",
    "        self.index_to_duration = index_to_duration\n",
    "        self.duration_to_index = {\n",
    "            duration: index for index, duration in enumerate(index_to_duration)\n",
    "        }\n",
    "\n",
    "    def sample_from(self, probs, temperature):\n",
    "        \"\"\"Re-weight `probs` by `temperature` and draw one index from the result.\n",
    "\n",
    "        Returns (sampled_index, reweighted_probs). Temperatures below 1 sharpen\n",
    "        the distribution, temperatures above 1 flatten it.\n",
    "        \"\"\"\n",
    "        probs = probs ** (1 / temperature)\n",
    "        probs = probs / np.sum(probs)\n",
    "        return np.random.choice(len(probs), p=probs), probs\n",
    "\n",
    "    def get_note(self, notes, durations, temperature):\n",
    "        \"\"\"Sample the next note and duration from the last position of the model output.\n",
    "\n",
    "        Token id 1 is rejected and resampled for both outputs (in a\n",
    "        TextVectorization vocabulary that id is the out-of-vocabulary token).\n",
    "        Returns the music21 element built by get_midi_note together with the\n",
    "        sampled ids, their string tokens and the re-weighted probabilities.\n",
    "        \"\"\"\n",
    "        sample_note_idx = 1\n",
    "        while sample_note_idx == 1:\n",
    "            sample_note_idx, note_probs = self.sample_from(\n",
    "                notes[0][-1], temperature\n",
    "            )\n",
    "            sample_note = self.index_to_note[sample_note_idx]\n",
    "\n",
    "        sample_duration_idx = 1\n",
    "        while sample_duration_idx == 1:\n",
    "            sample_duration_idx, duration_probs = self.sample_from(\n",
    "                durations[0][-1], temperature\n",
    "            )\n",
    "            sample_duration = self.index_to_duration[sample_duration_idx]\n",
    "\n",
    "        new_note = get_midi_note(sample_note, sample_duration)\n",
    "\n",
    "        return (\n",
    "            new_note,\n",
    "            sample_note_idx,\n",
    "            sample_note,\n",
    "            note_probs,\n",
    "            sample_duration_idx,\n",
    "            sample_duration,\n",
    "            duration_probs,\n",
    "        )\n",
    "\n",
    "    def generate(self, start_notes, start_durations, max_tokens, temperature):\n",
    "        \"\"\"Extend a prompt one token at a time until `max_tokens` tokens or \"START\".\n",
    "\n",
    "        Args:\n",
    "            start_notes: list of note strings to start from (left unmodified).\n",
    "            start_durations: list of duration strings, paired element-wise\n",
    "                with start_notes.\n",
    "            max_tokens: stop once the sequence holds this many tokens.\n",
    "            temperature: sampling temperature forwarded to sample_from.\n",
    "\n",
    "        Returns:\n",
    "            A list with one dict per generated step (keys: prompt, midi,\n",
    "            chosen_note, note_probs, duration_probs, atts). Every entry's\n",
    "            \"midi\" is the same stream object, so the last entry holds the\n",
    "            complete piece.\n",
    "        \"\"\"\n",
    "        # Work on copies so the caller's lists are not extended as a side effect\n",
    "        start_notes = list(start_notes)\n",
    "        start_durations = list(start_durations)\n",
    "\n",
    "        # Companion model exposing the outputs of the layer named \"attention\"\n",
    "        attention_model = models.Model(\n",
    "            inputs=self.model.input,\n",
    "            outputs=self.model.get_layer(\"attention\").output,\n",
    "        )\n",
    "\n",
    "        # Prompt entries missing from the vocabulary map to token id 1\n",
    "        start_note_tokens = [self.note_to_index.get(x, 1) for x in start_notes]\n",
    "        start_duration_tokens = [\n",
    "            self.duration_to_index.get(x, 1) for x in start_durations\n",
    "        ]\n",
    "        sample_note = None\n",
    "        sample_duration = None\n",
    "        info = []\n",
    "        midi_stream = music21.stream.Stream()\n",
    "\n",
    "        midi_stream.append(music21.clef.BassClef())\n",
    "\n",
    "        # Put the prompt itself into the stream first\n",
    "        for sample_note, sample_duration in zip(start_notes, start_durations):\n",
    "            new_note = get_midi_note(sample_note, sample_duration)\n",
    "            if new_note is not None:\n",
    "                midi_stream.append(new_note)\n",
    "\n",
    "        sounding_types = (\n",
    "            music21.chord.Chord,\n",
    "            music21.note.Note,\n",
    "            music21.note.Rest,\n",
    "        )\n",
    "\n",
    "        while len(start_note_tokens) < max_tokens:\n",
    "            x1 = np.array([start_note_tokens])\n",
    "            x2 = np.array([start_duration_tokens])\n",
    "            notes, durations = self.model.predict([x1, x2], verbose=0)\n",
    "\n",
    "            # Resample for as long as a chord / note / rest comes out with duration \"0.0\"\n",
    "            while True:\n",
    "                (\n",
    "                    new_note,\n",
    "                    sample_note_idx,\n",
    "                    sample_note,\n",
    "                    note_probs,\n",
    "                    sample_duration_idx,\n",
    "                    sample_duration,\n",
    "                    duration_probs,\n",
    "                ) = self.get_note(notes, durations, temperature)\n",
    "\n",
    "                if not (\n",
    "                    isinstance(new_note, sounding_types)\n",
    "                    and sample_duration == \"0.0\"\n",
    "                ):\n",
    "                    break\n",
    "\n",
    "            if new_note is not None:\n",
    "                midi_stream.append(new_note)\n",
    "\n",
    "            _, att = attention_model.predict([x1, x2], verbose=0)\n",
    "\n",
    "            info.append(\n",
    "                {\n",
    "                    \"prompt\": [start_notes.copy(), start_durations.copy()],\n",
    "                    \"midi\": midi_stream,\n",
    "                    \"chosen_note\": (sample_note, sample_duration),\n",
    "                    \"note_probs\": note_probs,\n",
    "                    \"duration_probs\": duration_probs,\n",
    "                    # scores of the last sequence position for the single batch item\n",
    "                    \"atts\": att[0, :, -1, :],\n",
    "                }\n",
    "            )\n",
    "            start_note_tokens.append(sample_note_idx)\n",
    "            start_duration_tokens.append(sample_duration_idx)\n",
    "            start_notes.append(sample_note)\n",
    "            start_durations.append(sample_duration)\n",
    "\n",
    "            # A sampled \"START\" token ends the generation early\n",
    "            if sample_note == \"START\":\n",
    "                break\n",
    "\n",
    "        return info\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        \"\"\"Generate a sample, show it and save it as output-<4-digit epoch>.mid.\"\"\"\n",
    "        info = self.generate(\n",
    "            [\"START\"], [\"0.0\"], max_tokens=GENERATE_LEN, temperature=0.5\n",
    "        )\n",
    "        midi_stream = info[-1][\"midi\"].chordify()\n",
    "        print(info[-1][\"prompt\"])\n",
    "        midi_stream.show()\n",
    "        output_dir = \"/app/notebooks/11_music/01_transformer/output\"\n",
    "        # Create the folder if it is missing so the write below does not fail\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "        midi_stream.write(\n",
    "            \"midi\",\n",
    "            fp=os.path.join(\n",
    "                output_dir,\n",
    "                \"output-\" + str(epoch).zfill(4) + \".mid\",\n",
    "            ),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48a82971-e609-4ced-9fd1-60eb9694c08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a model save checkpoint (weights only, written once per epoch)\n",
    "model_checkpoint_callback = callbacks.ModelCheckpoint(\n",
    "    filepath=\"./checkpoint/checkpoint.ckpt\",\n",
    "    save_weights_only=True,\n",
    "    save_freq=\"epoch\",\n",
    "    verbose=0,\n",
    ")\n",
    "\n",
    "# TensorBoard logs go to ./logs\n",
    "tensorboard_callback = callbacks.TensorBoard(log_dir=\"./logs\")\n",
    "\n",
    "# Callback that generates a sample at the end of every epoch, built from both vocabularies\n",
    "music_generator = MusicGenerator(notes_vocab, durations_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69f0c057-01c8-4a6a-9d82-8e83e66aa075",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Train with checkpointing, TensorBoard logging and per-epoch sample generation\n",
    "training_callbacks = [\n",
    "    model_checkpoint_callback,\n",
    "    tensorboard_callback,\n",
    "    music_generator,\n",
    "]\n",
    "model.fit(ds, epochs=EPOCHS, callbacks=training_callbacks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4de2131e-f13c-45af-ab8d-c361b0be8640",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the final trained model to ./models/model\n",
    "model.save(\"./models/model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f0445d3-5513-428b-995a-b337547d2a71",
   "metadata": {},
   "source": [
    "## 10. Generate music using the Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0776a68-d4c7-439b-a3fd-d9397b357662",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample a new piece, starting from the START token. GENERATE_LEN is the same\n",
    "# length limit the per-epoch callback uses, so both stay in sync.\n",
    "info = music_generator.generate(\n",
    "    [\"START\"], [\"0.0\"], max_tokens=GENERATE_LEN, temperature=0.5\n",
    ")\n",
    "# The last entry's stream holds the complete generated piece\n",
    "midi_stream = info[-1][\"midi\"].chordify()\n",
    "midi_stream.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8932e58a-5a9c-44a3-bff4-39c564fafdf4",
   "metadata": {},
   "source": [
    "## Write music to MIDI file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2402314a-ca87-468e-a9d1-b79403fe7065",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"/app/notebooks/11_music/01_transformer/output\"\n",
    "# Create the folder if it is missing so the write below does not fail\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "# A timestamp in the file name keeps earlier outputs from being overwritten\n",
    "timestr = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "midi_stream.write(\n",
    "    \"midi\",\n",
    "    fp=os.path.join(output_dir, \"output-\" + timestr + \".mid\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24eb14c3-9dab-42b9-8916-b9b3a4d38590",
   "metadata": {},
   "source": [
    "## Note probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92ad9f2-74ed-4432-b633-f74fd17b6a49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# grid[pitch, step] = probability given to that MIDI pitch at generation step `step`\n",
    "max_pitch = 70\n",
    "seq_len = len(info)\n",
    "grid = np.zeros((max_pitch, seq_len), dtype=np.float32)\n",
    "\n",
    "for j in range(seq_len):\n",
    "    for i, prob in enumerate(info[j][\"note_probs\"]):\n",
    "        try:\n",
    "            pitch = music21.note.Note(notes_vocab[i]).pitch.midi\n",
    "        except Exception:\n",
    "            # Vocabulary entry that music21 cannot turn into a single note\n",
    "            # (key / time signatures etc.): don't show it. Catching Exception\n",
    "            # rather than everything keeps Ctrl-C working.\n",
    "            continue\n",
    "        # Pitches that do not fit into the grid are left out of the plot\n",
    "        if 0 <= pitch < max_pitch:\n",
    "            grid[pitch, j] = prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4ce565-a5d4-4cec-b321-1e02d814b65c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the note probabilities, restricted to MIDI pitches 35-69\n",
    "pitch_lo, pitch_hi = 35, 70\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.set_yticks(list(range(pitch_lo, pitch_hi)))\n",
    "visible_rows = grid[pitch_lo:pitch_hi, :]\n",
    "plt.imshow(\n",
    "    visible_rows,\n",
    "    origin=\"lower\",\n",
    "    cmap=\"coolwarm\",\n",
    "    vmin=-0.5,\n",
    "    vmax=0.5,\n",
    "    extent=[0, seq_len, pitch_lo, pitch_hi],\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdae49db-6e0f-4071-a289-74310d684fab",
   "metadata": {},
   "source": [
    "## Attention Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf270a8f-da3f-4d21-b4f9-c19fcf7c4332",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot at most 20 generation steps. Generation can stop early (a sampled START\n",
    "# token ends it), so never ask for more steps than `info` actually holds.\n",
    "plot_size = min(20, len(info))\n",
    "\n",
    "att_matrix = np.zeros((plot_size, plot_size))\n",
    "prediction_output = []\n",
    "last_prompt = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a3fcec1-f2fc-4dbc-b997-d00a958b6388",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Column `step` holds, for each prompt position, the maximum over the first axis\n",
    "# of the stored scores (presumably the attention heads - confirm).\n",
    "for step in range(plot_size):\n",
    "    step_info = info[step]\n",
    "    att_matrix[: step + 1, step] = step_info[\"atts\"].max(axis=0)\n",
    "    prediction_output.append(step_info[\"chosen_note\"][0])\n",
    "    last_prompt.append(step_info[\"prompt\"][0][-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2742cf-ad18-4c23-8181-5d868ca03c0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "im = ax.imshow(att_matrix, cmap=\"Greens\", interpolation=\"nearest\")\n",
    "\n",
    "# Major ticks: one per cell, labelled with the chosen note (x) and the last prompt note (y)\n",
    "major_ticks = np.arange(plot_size)\n",
    "ax.set_xticks(major_ticks)\n",
    "ax.set_yticks(major_ticks)\n",
    "ax.set_xticklabels(prediction_output[:plot_size])\n",
    "ax.set_yticklabels(last_prompt[:plot_size])\n",
    "\n",
    "# Minor ticks sit on the cell borders and carry the black grid lines\n",
    "cell_edges = np.arange(-0.5, plot_size, 1)\n",
    "ax.set_xticks(cell_edges, minor=True)\n",
    "ax.set_yticks(cell_edges, minor=True)\n",
    "ax.grid(which=\"minor\", color=\"black\", linestyle=\"-\", linewidth=1)\n",
    "\n",
    "# x labels go on top, rotated so they read upwards\n",
    "ax.xaxis.tick_top()\n",
    "label_style = dict(rotation=90, ha=\"left\", va=\"center\", rotation_mode=\"anchor\")\n",
    "plt.setp(ax.get_xticklabels(), **label_style)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af9fd8cb-2877-4ba2-b15a-b8b5a6122c0c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
